import io

import torch
from ._utils import _type, _cuda
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast
import copy
import collections
from functools import lru_cache

T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
class _StorageBase:
    """Shared Python-side behaviour for storage objects.

    The methods whose body is ``...`` are placeholders: this class only
    declares their signatures and relies on them being provided elsewhere.
    """
    _cdata: Any
    is_cuda: bool = False
    is_sparse: bool = False
    device: torch.device

    # --- placeholder signatures -------------------------------------------
    def __init__(self, *args, **kwargs):
        ...

    def __len__(self) -> int:
        ...

    def __getitem__(self, idx):
        ...

    def copy_(self, source: T) -> T:
        ...

    def nbytes(self) -> int:
        ...

    def size(self) -> int:
        """Return the size of this storage, which is its ``nbytes()``."""
        return self.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> T:
        ...

    def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
        ...

    def element_size(self) -> int:
        ...

    def get_device(self) -> int:
        ...

    def data_ptr(self) -> int:
        ...

    # Defined in torch/csrc/generic/StorageSharing.cpp
    def _share_filename_(self):
        ...

    def _share_fd_(self):
        ...

    @classmethod
    def _new_using_filename(cls: Type[T], size: int) -> T:
        ...

    @classmethod
    def _new_using_fd(cls: Type[T], size: int) -> T:
        ...

    # --- concrete behaviour -----------------------------------------------
    def __str__(self):
        # One element per line, each prefixed by a single space, followed by
        # a bracketed footer with the type name and the length.
        rows = [str(self[pos]) for pos in range(len(self))]
        footer = f'\n[{torch.typename(self)} of size {len(self)}]'
        return ' ' + '\n '.join(rows) + footer

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return (self[pos] for pos in range(self.size()))

    def __copy__(self):
        return self.clone()

    def __deepcopy__(self, memo):
        # Clones are tracked in a private sub-dictionary of ``memo`` keyed by
        # ``_cdata``, so a storage met twice in one deepcopy is cloned once.
        seen = memo.setdefault('torch', {})
        key = self._cdata
        if key in seen:
            return seen[key]
        duplicate = self.clone()
        seen[self._cdata] = duplicate
        return duplicate

    def __reduce__(self):
        # Pickle by serializing through torch.save (legacy, non-zipfile
        # format); unpickling calls _load_from_bytes on the produced bytes.
        buffer = io.BytesIO()
        torch.save(self, buffer, _use_new_zipfile_serialization=False)
        payload = buffer.getvalue()
        return (_load_from_bytes, (payload,))

    def __sizeof__(self):
        return super().__sizeof__() + self.size()

    def clone(self):
        """Returns a copy of this storage"""
        if self.is_cuda:
            device_index = self.get_device()
        else:
            device_index = -1
        with torch.cuda.device(device_index):
            duplicate = type(self)(self.nbytes())
            return duplicate.copy_(self)

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        # The target class is the attribute of ``torch`` that has the same
        # name as this storage's class.
        target_cls = getattr(torch, self.__class__.__name__)
        return _type(self, target_cls)

    def _to(self, dtype):
        """Convert the contents to ``dtype`` and return the resulting storage."""
        # View this storage through an empty uint8 tensor, convert that
        # tensor, then take the storage of the result.
        byte_view = torch.tensor([], dtype=torch.uint8, device=self.device)
        byte_view = byte_view.set_(cast(Storage, self))
        converted = byte_view.to(dtype).storage()
        # If the conversion handed back the same memory, clone it so the
        # caller never receives an alias of this storage.
        if converted.data_ptr() == self.data_ptr():
            converted = converted.clone()
        return converted

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        if self.is_cuda:
            raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
        # NOTE: this import binds ``torch`` as a local name in this function,
        # so it must stay ahead of every use of ``torch`` below.
        import torch.cuda
        host_allocator = torch.cuda._host_allocator()  # type: ignore[attr-defined]
        pinned = type(self)(self.size(), allocator=host_allocator)
        return pinned.copy_(self)

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        from torch.multiprocessing import get_sharing_strategy
        # CUDA doesn't use POSIX shared memory, so only non-CUDA storages
        # are moved.
        if not self.is_cuda:
            if get_sharing_strategy() == 'file_system':
                self._share_filename_()
            else:
                self._share_fd_()
        return self

    @classmethod
    def _new_shared(cls, size):
        """Creates a new storage in shared memory with the same data type"""
        from torch.multiprocessing import get_sharing_strategy
        if cls.is_cuda:
            return cls(size)
        if get_sharing_strategy() == 'file_system':
            return cls._new_using_filename(size)
        return cls._new_using_fd(size)

    def _untyped(self):
        """Return this object itself."""
        return self


def _load_from_bytes(b):
    """Deserialize an object from the bytes ``b`` using ``torch.load``.

    Used as the reconstruction callable returned by the ``__reduce__``
    methods in this file.
    """
    stream = io.BytesIO(b)
    return torch.load(stream)


# Replace the placeholder `type` and `cuda` stubs declared on _StorageBase
# with the `_type` and `_cuda` implementations imported from ._utils.
_StorageBase.type = _type  # type: ignore[assignment]
_StorageBase.cuda = _cuda  # type: ignore[assignment]


@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    """Return the mapping from ``torch.dtype`` to storage class name.

    The result is cached, so every call returns the same dict object.
    """
    pairs = (
        (torch.double, 'DoubleStorage'),
        (torch.float, 'FloatStorage'),
        (torch.half, 'HalfStorage'),
        (torch.long, 'LongStorage'),
        (torch.int, 'IntStorage'),
        (torch.int16, 'ShortStorage'),
        (torch.int8, 'CharStorage'),
        (torch.uint8, 'ByteStorage'),
        (torch.bool, 'BoolStorage'),
        (torch.bfloat16, 'BFloat16Storage'),
        (torch.cdouble, 'ComplexDoubleStorage'),
        (torch.cfloat, 'ComplexFloatStorage'),
        (torch.qint8, 'QInt8Storage'),
        (torch.qint32, 'QInt32Storage'),
        (torch.quint8, 'QUInt8Storage'),
        (torch.quint4x2, 'QUInt4x2Storage'),
        (torch.quint2x4, 'QUInt2x4Storage'),
    )
    return dict(pairs)

@lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    """Return the inverse of ``_dtype_to_storage_type_map()``.

    Maps each storage class name to its ``torch.dtype``. The result is
    cached, so every call returns the same dict object.
    """
    forward = _dtype_to_storage_type_map()
    return {storage_name: dtype for dtype, storage_name in forward.items()}

class TypedStorage:
    """Pairs an untyped storage (``self._storage``) with a ``dtype``.

    Lengths and indices exposed by this class are counted in elements of
    ``self.dtype``; the wrapped storage reports its size in bytes, so the two
    are converted with ``element_size()``.

    The base class can only be instantiated with ``wrap_storage=`` and
    ``dtype=``. Subclasses are expected to expose ``dtype`` as a class-level
    attribute and additionally accept no arguments, an int size, or a
    sequence of values.
    """
    is_sparse = False

    @staticmethod
    def _interpret_quantized_dtype(dtype):
        """Return the dtype used to read and write elements of ``dtype``.

        Quantized dtypes map to an integer dtype; any other dtype is
        returned unchanged.
        """
        # Built per call rather than as a class attribute so that defining
        # the class does not evaluate torch dtype attributes.
        interpret_dtypes = {
            torch.quint8: torch.uint8,
            torch.quint4x2: torch.uint8,
            torch.quint2x4: torch.uint8,
            torch.qint32: torch.int32,
            torch.qint8: torch.int8
        }
        return interpret_dtypes.get(dtype, dtype)

    def fill_(self, value):
        """Set every element to ``value`` and return ``self``."""
        self[0:len(self)] = value
        return self

    def __init__(self, *args, **kwargs):
        """Wrap an existing untyped storage or allocate a new one.

        Accepted forms:
          * ``wrap_storage=<UntypedStorage>, dtype=<torch.dtype>`` on
            ``TypedStorage`` itself;
          * ``wrap_storage=<UntypedStorage>`` on a subclass;
          * on a subclass only: no arguments, a single int-convertible size,
            or a single sequence of values.

        Raises:
            TypeError: for an unsupported argument combination, or when the
                wrapped storage's module does not match the subclass's.
            AssertionError: for the argument checks done with ``assert``.
        """
        arg_error_msg = (
            f'{type(self)} constructor received an invalid combination '
            f'of arguments - got args={tuple(type(arg) for arg in args)}, '
            f'kwargs={ {key: type(val) for key, val in kwargs.items()} }, but '
            'expected one of:\n'
            ' * no arguments\n'
            ' * (int size)\n'
            ' * (Sequence data)\n')
        if type(self) == TypedStorage:
            arg_error_msg += ' * (wrap_storage=<UntypedStorage>, dtype=<torch.dtype>)'
        else:
            arg_error_msg += ' * (wrap_storage=<UntypedStorage>)'

        if 'wrap_storage' in kwargs:
            assert len(args) == 0, (
                "No positional arguments should be given when using "
                "'wrap_storage'")

            if type(self) == TypedStorage:
                assert 'dtype' in kwargs, (
                    "When using 'wrap_storage', 'dtype' also must be specified")
                assert len(kwargs) == 2, (
                    "Only 'wrap_storage' and 'dtype' should be given, but got: "
                    f"{kwargs}")
                dtype = kwargs['dtype']
                assert isinstance(dtype, torch.dtype)
                self.dtype = dtype

            else:
                # Subclasses must already expose `dtype` themselves.
                assert hasattr(self, 'dtype')
                assert len(kwargs) == 1, (
                    f"Only 'wrap_storage' should be given, but got: {kwargs.keys()}")
                dtype = self.dtype

            storage = kwargs['wrap_storage']

            if not isinstance(storage, (torch.UntypedStorage, torch.cuda.UntypedStorage)):
                raise TypeError(arg_error_msg)
            if type(self) != TypedStorage and storage.__module__ != self.__module__:
                raise TypeError((
                    arg_error_msg +
                    f'\n`storage` `module {storage.__module__}` does not match '
                    f'module of {type(self)}'))
            self._storage = storage

        else:
            assert type(self) != TypedStorage, (
                "Calling __init__ this way is only supported in TypedStorage's "
                "child classes. TypedStorage can only be directly instantiated "
                "when kwargs 'wrap_storage' and 'dtype' are given.")

            assert len(kwargs) == 0, "invalid keyword arguments"

            def isint(x):
                try:
                    int(x)
                except TypeError:
                    return False
                return True

            # NOTE(review): `eval(self.__module__)` resolves the class's
            # defining module by name in this file's namespace; the
            # comparisons below suggest it is expected to be `torch` or
            # `torch.cuda` - confirm for other subclasses.
            if len(args) == 0:
                self._storage = eval(self.__module__).UntypedStorage()

            elif len(args) == 1 and isint(args[0]):
                # The untyped storage is sized in bytes.
                self._storage = eval(self.__module__).UntypedStorage(int(args[0]) * self.element_size())

            elif len(args) == 1 and isinstance(args[0], collections.abc.Sequence):
                # Build a temporary tensor from the data (through the integer
                # dtype for quantized dtypes) and keep its untyped storage.
                tmp_tensor = torch.tensor(
                    args[0],
                    dtype=self._interpret_quantized_dtype(self.dtype),
                    device='cuda' if eval(self.__module__) is torch.cuda else 'cpu')

                self._storage = tmp_tensor.storage()._untyped()

            else:
                raise TypeError(arg_error_msg)

    @property
    def is_cuda(self):
        """Whether the wrapped storage's device type is ``'cuda'``."""
        return self._storage.device.type == 'cuda'

    def _untyped(self):
        """Return the wrapped untyped storage."""
        return self._storage

    def _new_wrapped_storage(self, untyped_storage):
        """Wrap ``untyped_storage`` in a typed storage with this dtype."""
        module = eval(untyped_storage.__module__)
        assert type(untyped_storage) == module.UntypedStorage

        if type(self) == TypedStorage:
            return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
        else:
            # NOTE: We need to use the module of untyped_storage in case self's
            # module is different, e.g. if self is on CPU and untyped_storage
            # is on CUDA, and vice versa
            return getattr(module, type(self).__name__)(wrap_storage=untyped_storage)

    def __len__(self):
        return self._storage.nbytes() // self.element_size()

    def _maybe_wrap_index(self, idx, is_stop=False):
        """Validate ``idx`` and convert it to a non-negative position.

        Args:
            idx: ``None`` or an ``int`` (exactly ``int``; other types,
                including ``bool``, are rejected).
            is_stop: when true, ``idx`` is treated as an exclusive end
                position, so ``idx == size`` is allowed and ``None`` means
                ``size``. Otherwise ``None`` means ``0``.

        Raises:
            TypeError: if ``idx`` is neither ``None`` nor an ``int``.
            IndexError: if ``idx`` is out of range for this storage.
        """
        if idx is None:
            return self.size() if is_stop else 0

        if type(idx) != int:
            raise TypeError(
                f"can't index a {type(self)} with {type(idx)}")

        size = self.size()
        if is_stop:
            if (idx > size) or (idx < -size):
                raise IndexError(
                    f'index {idx} out of range for storage of size {size}')
            # Only negative values need wrapping. Returning 0 directly avoids
            # `0 % 0` (ZeroDivisionError) on an empty storage; a negative idx
            # that passed the range check implies size >= 1.
            return idx if idx >= 0 else idx % size

        if (idx >= size) or (idx < -size):
            raise IndexError(
                f'index {idx} out of range for storage of size {size}')
        # size >= 1 here: with size == 0 every idx fails the check above.
        return idx % size

    def __setitem__(self, idx, value):
        if not isinstance(idx, (int, slice)):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
        tmp_dtype = self._interpret_quantized_dtype(self.dtype)
        if tmp_dtype != self.dtype:
            # Quantized dtype: write through a view of the same untyped
            # storage typed with the corresponding integer dtype.
            tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage(
                wrap_storage=self._storage,
                dtype=tmp_dtype))
        else:
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)

        tmp_tensor[idx] = value

    def __getitem__(self, idx):
        # NOTE: Before TypedStorage existed, indexing with a slice used to be
        # possible for <type>Storage objects. However, it would return
        # a storage view, which would be a hassle to implement in TypedStorage,
        # so it was disabled
        if isinstance(idx, slice):
            raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
        elif not isinstance(idx, int):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")

        plain_dtype = self._interpret_quantized_dtype(self.dtype)
        if plain_dtype != self.dtype:
            # Quantized dtype: read through a wrapper of the same untyped
            # storage typed with the corresponding integer dtype.
            return TypedStorage(
                wrap_storage=self._storage,
                dtype=plain_dtype)[idx]

        idx_wrapped = self._maybe_wrap_index(idx)
        tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
        return tmp_tensor[idx_wrapped].item()

    def copy_(self, source: 'T', non_blocking=None):
        """Copy ``source``'s untyped storage into this one; return ``self``."""
        self._storage.copy_(source._untyped(), non_blocking)
        return self

    def nbytes(self):
        """Return the size in bytes reported by the wrapped storage."""
        return self._storage.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> 'Union[T, str]':
        """Return the type name, or delegate the conversion to the wrapped storage.

        With ``dtype`` omitted, returns ``"<module>.<class name>"`` of this
        object. Otherwise returns ``self._storage.type(dtype, non_blocking)``.
        """
        if dtype is None:
            return '.'.join([self.__module__, type(self).__name__])
        else:
            return self._storage.type(dtype, non_blocking)

    def cuda(self, device=None, non_blocking=False, **kwargs) -> 'T':
        """Call ``cuda`` on the wrapped storage and re-wrap the result."""
        cuda_storage = self._storage.cuda(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(cuda_storage)

    def element_size(self):
        """Return the size of one element of ``self.dtype``."""
        return torch._utils._element_size(self.dtype)

    def get_device(self) -> int:
        return self._storage.get_device()

    def __str__(self):
        data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
        if type(self) == TypedStorage:
            return data_str + (
                f'\n[{torch.typename(self)} with dtype {self.dtype} '
                f'of size {len(self)}]')
        else:
            return data_str + f'\n[{torch.typename(self)} of size {len(self)}]'

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        return self._new_wrapped_storage(copy.copy(self._storage))

    def __deepcopy__(self, memo):
        return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))

    def __sizeof__(self):
        return super().__sizeof__() + self.nbytes()

    def clone(self):
        """Returns a copy of this storage"""
        return self._new_wrapped_storage(self._storage.clone())

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        return self._new_wrapped_storage(self._storage.cpu())

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        return self._new_wrapped_storage(self._storage.pin_memory())

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        self._storage.share_memory_()
        return self

    @classmethod
    def _new_shared(cls, size):
        """Creates a new storage in shared memory with the same data type"""
        module = eval(cls.__module__)
        # Read the dtype from the class, as from_file and _new_shared_filename
        # do, instead of building a throwaway instance to get the element size.
        untyped_storage = module.UntypedStorage._new_shared(
            size * torch._utils._element_size(cls.dtype))
        return cls(wrap_storage=untyped_storage)

    @property
    def _cdata(self):
        return self._storage._cdata

    @property
    def device(self):
        return self._storage.device

    def size(self):
        """Return the number of elements (same as ``len(self)``)."""
        return len(self)

    def pickle_storage_type(self):
        """Return the storage class name registered for ``self.dtype``.

        Raises:
            KeyError: if ``self.dtype`` is not in the dtype-to-name map.
        """
        try:
            return _dtype_to_storage_type_map()[self.dtype]
        except KeyError:
            # The new message already names the dtype, so the bare lookup
            # failure is dropped from the traceback.
            raise KeyError(f'dtype {self.dtype} is not recognized') from None

    def __reduce__(self):
        # Pickle by serializing through torch.save (legacy, non-zipfile
        # format); unpickling calls _load_from_bytes on the produced bytes.
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def data_ptr(self):
        return self._storage.data_ptr()

    def resize_(self, size):
        """Resize the wrapped storage to hold ``size`` elements."""
        self._storage.resize_(size * self.element_size())

    @classmethod
    def _free_weak_ref(cls, *args, **kwargs):
        return eval(cls.__module__).UntypedStorage._free_weak_ref(*args, **kwargs)

    def _weak_ref(self, *args, **kwargs):
        return self._storage._weak_ref(*args, **kwargs)

    @classmethod
    def from_buffer(cls, *args, **kwargs):
        """Create a storage of this subclass's dtype from a buffer.

        Arguments are forwarded to ``UntypedStorage.from_buffer`` with
        ``dtype`` set to ``cls.dtype``.

        Raises:
            RuntimeError: when called on ``TypedStorage`` itself, or when a
                ``dtype`` keyword or five positional arguments are given.
        """
        if cls == TypedStorage:
            raise RuntimeError(
                'from_buffer: only supported for subclasses of TypedStorage')

        if 'dtype' in kwargs or len(args) == 5:
            raise RuntimeError((
                "from_buffer: 'dtype' can only be specified in "
                "UntypedStorage.from_buffer"))

        # Read the dtype from the class, as from_file does, instead of
        # building a throwaway instance.
        kwargs['dtype'] = cls.dtype

        untyped_storage = eval(cls.__module__).UntypedStorage.from_buffer(*args, **kwargs)
        return cls(wrap_storage=untyped_storage)

    def _to(self, dtype):
        """Convert the contents to ``dtype`` and return the resulting storage."""
        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
        # If the conversion handed back the same memory, clone it so the
        # caller never receives an alias of this storage.
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    @classmethod
    def from_file(cls, filename, shared, size):
        """Create a storage of ``size`` elements through ``UntypedStorage.from_file``.

        Raises:
            RuntimeError: when called on ``TypedStorage`` itself.
        """
        if cls == TypedStorage:
            raise RuntimeError('from_file can only be called on derived classes')
        # The untyped storage is sized in bytes.
        untyped_storage = eval(cls.__module__).UntypedStorage.from_file(
            filename,
            shared,
            size * torch._utils._element_size(cls.dtype))
        storage = cls(wrap_storage=untyped_storage)
        return storage

    @classmethod
    def _expired(cls, *args, **kwargs):
        return eval(cls.__module__).UntypedStorage._expired(*args, **kwargs)

    def is_pinned(self):
        return self._storage.is_pinned()

    def _write_file(self, *args, **kwargs):
        return self._storage._write_file(*args, **kwargs)

    def _set_from_file(self, *args, **kwargs):
        return self._storage._set_from_file(*args, **kwargs)

    def _set_cdata(self, *args, **kwargs):
        return self._storage._set_cdata(*args, **kwargs)

    def _share_cuda_(self, *args, **kwargs):
        return self._storage._share_cuda_(*args, **kwargs)

    def is_shared(self):
        return self._storage.is_shared()

    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs):
        return eval(cls.__module__).UntypedStorage._new_shared_cuda(*args, **kwargs)

    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs):
        return eval(cls.__module__).UntypedStorage._new_with_weak_ptr(*args, **kwargs)

    def _share_filename_(self, *args, **kwargs):
        # The size reported by the untyped storage is divided by the element
        # size before being returned.
        manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
        return manager_handle, storage_handle, size // self.element_size()

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        # `size` is in elements; the untyped storage is given a byte count.
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(wrap_storage=eval(cls.__module__).UntypedStorage._new_shared_filename(manager, obj, bytes_size))

    def _shared_decref(self):
        self._storage._shared_decref()
        return self

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        return eval(cls.__module__).UntypedStorage._release_ipc_counter(*args, **kwargs)

    def _shared_incref(self, *args, **kwargs):
        return self._storage._shared_incref(*args, **kwargs)

    def _share_fd_(self, *args, **kwargs):
        # The size reported by the untyped storage is divided by the element
        # size before being returned.
        fd, size = self._storage._share_fd_(*args, **kwargs)
        return fd, size // self.element_size()

def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    """Return the ``torch.dtype`` registered for a pickled storage type name.

    Args:
        pickle_storage_type: a storage class name looked up in
            ``_storage_type_to_dtype_map()``.

    Raises:
        KeyError: if the name is not in the map.
    """
    try:
        return _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError:
        # The new message already names the offending key, so the bare
        # lookup failure is dropped from the traceback.
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized') from None
